{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KLHiTPXNTf2a"
      },
      "source": [
        "##### Copyright 2025 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "oTuT5CsaTigz"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZNM-D0pLXZeR"
      },
      "source": [
        "# Gemini API: Chat with SQL using LangChain"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PRZo8H09Bs6u"
      },
      "source": [
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/cookbook/blob/main/examples/langchain/Chat_with_SQL_using_langchain.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" height=30/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mOGNjAZMwMIk"
      },
      "source": [
        "Reading an SQL database can be challenging for humans. However, with accurate prompts, Gemini models can generate answers based on the data. Through the use of the Gemini API, you will be able retrieve necessary information by chatting with a SQL database."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "CaSSapCIcoxy"
      },
      "outputs": [],
      "source": [
        "%pip install -U -q \"google-genai>=1.7.0\" langchain langchain-community langchain-google-genai"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "bBoPE7f7cmKA"
      },
      "outputs": [],
      "source": [
        "import sqlite3\n",
        "\n",
        "from langchain.chains import create_sql_query_chain, LLMChain\n",
        "from langchain.prompts import PromptTemplate\n",
        "from langchain_google_genai import ChatGoogleGenerativeAI\n",
        "from langchain_core.output_parsers import StrOutputParser\n",
        "from langchain_community.utilities import SQLDatabase\n",
        "from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool\n",
        "from operator import itemgetter\n",
        "from langchain_core.runnables import RunnablePassthrough\n",
        "from google import genai\n",
        "from IPython.display import Markdown"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FQOGMejVu-6D"
      },
      "source": [
        "## Configure your API key\n",
        "\n",
        "To run the following cell, your API key must be stored in a Colab Secret named `GOOGLE_API_KEY`. If you don't already have an API key, or you're not sure how to create a Colab Secret, see [Authentication](https://github.com/google-gemini/cookbook/blob/main/quickstarts/Authentication.ipynb) for an example.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "ysayz8skEfBW"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "GOOGLE_API_KEY=userdata.get('GOOGLE_API_KEY')\n",
        "\n",
        "os.environ[\"GOOGLE_API_KEY\"] = GOOGLE_API_KEY"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NzyyOcsUR0HO"
      },
      "source": [
        "## Setting up the database\n",
        "To query a database, you first need to set one up.\n",
        "\n",
        "1. **Load the California Housing Dataset:** Load the dataset from sklearn.datasets and extract it into a DataFrame.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "lK85M4XGRzsM"
      },
      "outputs": [],
      "source": [
        "from sklearn.datasets import fetch_california_housing\n",
        "\n",
        "california_housing_bunch = fetch_california_housing(as_frame=True)\n",
        "california_housing_df = california_housing_bunch.frame"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GIjbnY4X_Kwd"
      },
      "source": [
        "2. **Connect to the SQLite database:** The database will be stored in the specified file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "1Kbqtjo4V2qM"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "20640"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "conn = sqlite3.connect(\"mydatabase.db\")\n",
        "\n",
        "# Write the DataFrame to a SQL table named 'housing'.\n",
        "california_housing_df.to_sql(\"housing\", conn, index=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "D7ChiDUmXC3K"
      },
      "outputs": [],
      "source": [
        "# Create an SQLDatabase object\n",
        "db = SQLDatabase.from_uri(\"sqlite:///mydatabase.db\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jHGQdREVYFxo"
      },
      "source": [
        "## Question to query\n",
        "With the database connection established, the `SQLDatabase` object now contains information about our database, which the model can access.\n",
        "\n",
        "You can now start asking the LLM to generate queries.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "g0xP-OStxDW8"
      },
      "outputs": [
        {
          "data": {
            "text/markdown": "\nCREATE TABLE housing (\n\t\"MedInc\" REAL, \n\t\"HouseAge\" REAL, \n\t\"AveRooms\" REAL, \n\t\"AveBedrms\" REAL, \n\t\"Population\" REAL, \n\t\"AveOccup\" REAL, \n\t\"Latitude\" REAL, \n\t\"Longitude\" REAL, \n\t\"MedHouseVal\" REAL\n)\n\n/*\n3 rows from housing table:\nMedInc\tHouseAge\tAveRooms\tAveBedrms\tPopulation\tAveOccup\tLatitude\tLongitude\tMedHouseVal\n8.3252\t41.0\t6.984126984126984\t1.0238095238095237\t322.0\t2.5555555555555554\t37.88\t-122.23\t4.526\n8.3014\t21.0\t6.238137082601054\t0.9718804920913884\t2401.0\t2.109841827768014\t37.86\t-122.22\t3.585\n7.2574\t52.0\t8.288135593220339\t1.073446327683616\t496.0\t2.8022598870056497\t37.85\t-122.24\t3.521\n*/",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# you can see what information is available\n",
        "Markdown(db.get_table_info())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "MFnV5-dUaa77"
      },
      "outputs": [],
      "source": [
        "# Define query chain\n",
        "llm = ChatGoogleGenerativeAI(model=\"gemini-2.5-flash\", temperature=0)\n",
        "write_query_chain = create_sql_query_chain(llm, db)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6EdrtzVX0zcm"
      },
      "source": [
        "You use `create_sql_query_chain` that fits our database. It provides default prompts for various types of SQL including Oracle, Google SQL, MySQL and more.\n",
        "\n",
        "\n",
        "In this case, default prompt is suitable for the task. However, feel free to experiment with writing this part of our chain yourself to suit your preferences."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "r2TjWih70ro6"
      },
      "outputs": [
        {
          "data": {
            "text/markdown": "You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date('now') function to get the current date, if the question involves \"today\".\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use the following tables:\n{table_info}\n\nQuestion: {input}",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "Markdown(write_query_chain.get_prompts()[0].template)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "dGONOILk0sr2"
      },
      "outputs": [
        {
          "data": {
            "text/markdown": "```sqlite\nSELECT sum(\"Population\") FROM housing\n```",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "response = write_query_chain.invoke({\"question\": \"What is the total population?\"})\n",
        "display(Markdown(response))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "1rwKuD6eYhzv"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'[(29421840.0,)]'"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "db.run('SELECT SUM(\"Population\") FROM housing')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fb_72hgXagco"
      },
      "source": [
        "Great! The SQL query is correct, but it needs proper formatting before it can be executed directly by the database.\n"
      ]
    },
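    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of one deterministic way to clean up a response like the one above, a regular expression can strip a Markdown code fence from around the query. The helper name `strip_sql_fence` is hypothetical and not part of LangChain, and the pattern assumes the text contains at most one fenced block."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "\n",
        "def strip_sql_fence(text: str) -> str:\n",
        "    \"\"\"Return the SQL inside a Markdown code fence, or the stripped text if there is no fence.\"\"\"\n",
        "    # Match an opening fence with an optional language tag, then capture up to the closing fence.\n",
        "    match = re.search(r\"```[a-zA-Z]*[ \\t]*\\n(.*?)```\", text, flags=re.DOTALL)\n",
        "    return (match.group(1) if match else text).strip()\n",
        "\n",
        "\n",
        "# Hypothetical model output, shaped like the fenced response shown above.\n",
        "fenced_response = '```sqlite\\nSELECT sum(\"Population\") FROM housing\\n```'\n",
        "strip_sql_fence(fenced_response)"
      ]
    },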
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3ZOGsW9YcL-I"
      },
      "source": [
        "## Validating the query\n",
        "You will pass the output of the previous query to a model that will extract just the SQL query and ensure its validity."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "ptuPPTordp6G"
      },
      "outputs": [],
      "source": [
        "validate_prompt = PromptTemplate(\n",
        "    input_variables=[\"not_formatted_query\"],\n",
        "    template=\"\"\"\n",
        "        You are going to receive a text that contains a SQL query. Extract that query.\n",
        "        Make sure that it is a valid SQL command that can be passed directly to the Database.\n",
        "        Avoid using Markdown for this task.\n",
        "        Text: {not_formatted_query}\n",
        "    \"\"\"\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "5KP9yy1edRfJ"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'SELECT sum(\"Population\") FROM housing'"
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "validate_chain = write_query_chain | validate_prompt | llm | StrOutputParser()\n",
        "validate_chain.invoke({\"question\": \"What is the total population?\"})"
      ]
    },
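    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model-based check above can be complemented by letting SQLite itself compile the query. The following is a sketch under stated assumptions, not part of the original chain: `is_valid_sqlite_query` is a hypothetical helper, and it runs against a throwaway in-memory table rather than `mydatabase.db`. It only checks that the statement parses and that its tables exist, not whether the query answers the question."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import sqlite3\n",
        "\n",
        "\n",
        "def is_valid_sqlite_query(connection: sqlite3.Connection, query: str) -> bool:\n",
        "    \"\"\"Return True if SQLite can compile the query against the connection's schema.\"\"\"\n",
        "    try:\n",
        "        # EXPLAIN compiles the statement and returns its plan instead of the query's rows.\n",
        "        connection.execute(f\"EXPLAIN {query}\")\n",
        "        return True\n",
        "    except sqlite3.Error:\n",
        "        return False\n",
        "\n",
        "\n",
        "# Throwaway in-memory database so the sketch does not touch mydatabase.db.\n",
        "demo_conn = sqlite3.connect(\":memory:\")\n",
        "demo_conn.execute('CREATE TABLE housing (\"Population\" REAL)')\n",
        "\n",
        "print(is_valid_sqlite_query(demo_conn, 'SELECT sum(\"Population\") FROM housing'))  # True\n",
        "print(is_valid_sqlite_query(demo_conn, 'SELEC sum(\"Population\") FROM housing'))  # False: syntax error\n",
        "print(is_valid_sqlite_query(demo_conn, 'SELECT sum(\"Population\") FROM homes'))  # False: no such table\n",
        "demo_conn.close()"
      ]
    },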
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oskPmggygOJP"
      },
      "source": [
        "## Automatic querying\n",
        "Now, let's automate the process of querying the database using *QuerySQLDataBaseTool*. This tool can receive text from previous parts of the chain, execute the query, and return the answer.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "mTfFkHVgcDuo"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "<ipython-input-16-580ecc1223c9>:1: LangChainDeprecationWarning: The class `QuerySQLDataBaseTool` was deprecated in LangChain 0.3.12 and will be removed in 1.0. An updated version of the class exists in the :class:`~langchain-community package and should be used instead. To use it run `pip install -U :class:`~langchain-community` and import as `from :class:`~langchain_community.tools import QuerySQLDatabaseTool``.\n",
            "  execute_query = QuerySQLDataBaseTool(db=db)\n"
          ]
        },
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'[(29421840.0,)]'"
            ]
          },
          "execution_count": 16,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "execute_query = QuerySQLDataBaseTool(db=db)\n",
        "execute_chain = validate_chain | execute_query\n",
        "execute_chain.invoke({\"question\": \"What is the total population?\"})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "amQh9IvQlBH0"
      },
      "source": [
        "## Generating answer\n",
        "You are almost done!\n",
        "\n",
        "To enhance our output, you'll use LLM not only to get the number but to get properly formatted and natural language response."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "658lkr4xlD6q"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'The total population is 29,421,840.'"
            ]
          },
          "execution_count": 17,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "answer_prompt = PromptTemplate.from_template(\"\"\"\n",
        "    You are going to receive a original user question, generated SQL query, and result of said query. You should use this information to answer the original question. Use only information provided to you.\n",
        "\n",
        "    Original Question: {question}\n",
        "    SQL Query: {query}\n",
        "    SQL Result: {result}\n",
        "    Answer: \"\"\"\n",
        ")\n",
        "\n",
        "answer_chain = (\n",
        "    RunnablePassthrough.assign(query=validate_chain).assign(\n",
        "        result=itemgetter(\"query\") | execute_query\n",
        "    )\n",
        "    | answer_prompt | llm | StrOutputParser()\n",
        ")\n",
        "\n",
        "answer_chain.invoke({\"question\": \"What is the total population?\"})"
      ]
    },
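    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how `RunnablePassthrough.assign` threads data through the chain above, here is a small sketch that swaps the model and the database for plain functions. `fake_write_query` and `fake_execute_query` are hypothetical stand-ins for `validate_chain` and `execute_query`, so the cell runs without an API key. Each `assign` call keeps the incoming dictionary and adds one key, which is how `answer_prompt` ends up receiving `question`, `query`, and `result` together."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from operator import itemgetter\n",
        "\n",
        "from langchain_core.runnables import RunnableLambda, RunnablePassthrough\n",
        "\n",
        "# Hypothetical stand-ins for validate_chain and execute_query; no model or database is called.\n",
        "fake_write_query = RunnableLambda(lambda inputs: \"SELECT 1\")\n",
        "fake_execute_query = RunnableLambda(lambda query: f\"result of {query}\")\n",
        "\n",
        "demo_chain = RunnablePassthrough.assign(query=fake_write_query).assign(\n",
        "    result=itemgetter(\"query\") | fake_execute_query\n",
        ")\n",
        "\n",
        "# The returned dict keeps \"question\" and adds the \"query\" and \"result\" keys.\n",
        "demo_chain.invoke({\"question\": \"demo\"})"
      ]
    },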
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SJmceCo2Lebi"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "Congratulations! You've successfully created a functional chain to interact with SQL. Now, feel free to explore further by asking different questions."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "Chat_with_SQL_using_langchain.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
